/*
    Copyright 2006-2011 Patrik Jonsson, sunrise@familjenjonsson.org

    This file is part of Sunrise.

    Sunrise is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.

    Sunrise is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.

*/

/// \file cuda_grain_temp.cu
/// Contains CUDA temperature and SED calculation functions.

// $Id: grain.h 2341 2009-08-03 21:14:33Z sunrise@familjenjonsson.org $

#include "cuda_grain_temp.h"
#include "math.h"
#include "cuda_runtime_api.h"
#include "cutil.h"
#include <cassert>
#include <stdio.h>
#include <iostream>

#ifdef MCRX_DEBUG_LEVEL
#define DEBUG(level, statement)     \
    if(MCRX_DEBUG_LEVEL>=level) {   \
      statement;                    \
      }
#else
#define DEBUG(level, statement)
#endif

using namespace std;

/** Reports any pending CUDA error on stdout and asserts that none
    occurred. Returns true iff the last CUDA call succeeded. */
bool checkerr()
{
  const cudaError_t status = cudaGetLastError();
  const bool ok = (status == cudaSuccess);
  if(!ok)
    cout << "CUDA error: " << cudaGetErrorString(status) << endl;
  assert(status==cudaSuccess);
  return ok;
}

#define BLOCKSIZE 16
typedef float T_float;
typedef float T_float32;

__constant__ T_float piD, sed_normD;
const T_float pif=4*atanf(1.0f);

// easserts are only used for device emulation mode
//#ifdef __DEVICE_EMULATION__
//#define eassert(x) assert(x)
//#else
#define eassert(x) 
//#endif

/** Returns the Planck function for a blackbody of temperature 1/invT
    at wavelength 1/invlambda. Inverse wavelength and temperature are
    passed for efficiency, since the callers already have those.
    NOTE: the 2hc^2 prefactor is rescaled to 1 to stay within
    single-precision range; the caller compensates by dividing the
    heating by 1.19e-16 (see calculate_heating). */
__device__ T_float B_lambda (T_float invlambda, T_float invT) 
{
  // Prefactor 2hc^2 = 1.19e-16 kg m^4/s^3 = [W m^2], converted to
  // [W um^2], but rescaled to unity as described above.
  const T_float prefactor = 1.0f;

  // hc/k in units [um K], used in the exponential.
  const T_float hc_over_k = 1.43876866033e-2f;

  // The intrinsic __expf is much faster than expf/expm1f without
  // noticeably affecting accuracy here.
  const T_float bose = 1.0f/(__expf(hc_over_k*invlambda*invT)-1.0f); 

  // B_lambda = prefactor / lambda^5 / (exp(hc/(lambda k T)) - 1)
  const T_float result = prefactor*invlambda*invlambda*invlambda*invlambda*invlambda*bose;

  eassert((invlambda==0) || (result==result));
  eassert((invlambda==0) || (result>=0));
  eassert((invlambda==0) || (result<HUGE_VAL));

  return result;
}


/** Returns the temperature derivative of the Planck function for a
    blackbody of temperature 1/invT at wavelength 1/invlambda. Inverse
    wavelength and temperature are passed for efficiency. Uses the
    same unit prefactor rescaling as B_lambda. */
__device__ 
T_float dB_lambda_dT (T_float invlambda, T_float invT) 
{

  // hc/k in units [um K], used in the exponential.
  const T_float hc_over_k = 1.43876866033e-2f;

  // The 2hc^2 prefactor ([W um^2]) is rescaled to 1 to stay within
  // single-precision range; the caller divides the heating by
  // 1.19e-16 to compensate.
  const T_float prefactor = 1.0f;

  const T_float pd = prefactor*hc_over_k; //1.19104393407e-16f * 1.43876866033e-2f;

  // Cache the exponential term. __expf is a lot faster than expm1f,
  // but __fdividef showed no measurable gain here.
  const T_float iem1 = 1.0f/(__expf (hc_over_k*invlambda*invT)-1);

  // The analytic form e/em1^2 is ill-behaved when e overflows to inf,
  // so use the algebraically equivalent (1/em1 + 1/em1^2) instead.
  const T_float deriv= pd*(iem1+iem1*iem1)*invT*invT*
    invlambda*invlambda*invlambda*invlambda*invlambda*invlambda;

  eassert(deriv==deriv);
  eassert(abs(deriv)<HUGE_VAL);
  return deriv;
}


/** The residual function solved by NRsolve: heating minus the power
    emitted at temperature T, summed over the wavelength grid held in
    the shared-memory arrays sigmaS (sigma*dlambda) and lambdaS. */
__device__ 
T_float func (T_float T, T_float heating, 
	      T_float* sigmaS, T_float* lambdaS, 
	      size_t nl) 
{
  const T_float invT=1.0f/T;
  // Accumulate the (small) emission terms first and add the large
  // heating value last; starting the sum with a large number would
  // reduce numerical accuracy.
  T_float emitted = 0.0f;
  for (size_t l=0; l<nl; ++l)
    emitted -= sigmaS[l] * B_lambda (__fdividef(1.0f,lambdaS[l]), invT);

  const T_float f = emitted + heating;
  eassert(f==f);
  return f;
}


/** The derivative (with respect to T) of func, used by the NRsolve
    Newton-Raphson solver. sigmaS and lambdaS are shared-memory arrays
    of length nl holding sigma*dlambda and the wavelengths. */
__device__ 
T_float der (T_float T, T_float* sigmaS, T_float* lambdaS, 
	      size_t nl) 
{
  // Use a 1.0f literal (the previous 1. was a double literal, which
  // forced a slow double-precision divide in device code and was
  // inconsistent with func() above).
  const T_float invT=1.0f/T;
  T_float d = 0.0f;
  for (size_t l=0; l<nl; ++l) {
    d -= sigmaS[l] * dB_lambda_dT (__fdividef(1.0f,lambdaS[l]), invT);
  }
  eassert(d==d);
  return d;
}


/** Newton-Raphson solver for the grain equilibrium temperature.
    sigmaS and lambdaS are shared memory arrays with the sigma*dlambda
    and lambda values for the current calculation. The function to be
    solved and its derivative are defined in func and der.
    \param x0        initial guess for the solution (temperature)
    \param accuracy  absolute tolerance on the residual of func
    \param heating   heating rate, passed through to func
    \returns the converged solution, or NaN if the iteration reached
    an extremum (zero derivative) or failed to converge within
    maxsteps iterations. */
__device__ 
T_float NRsolve(T_float x0, T_float accuracy, T_float heating, 
		T_float* sigmaS, T_float* lambdaS, size_t nl)
{
  const int maxsteps = 50;

  // f is assigned inside the loop condition, so it is always valid
  // in the loop body.
  T_float f;
  T_float x = x0;
  int Nstep = 0;
  while( (fabsf(f=func(x, heating, sigmaS, lambdaS, nl))>accuracy) ) {
    Nstep++;

    const T_float fp = der(x, sigmaS, lambdaS, nl);
    if(fp==0.0f) {
      // at an extremum the NR step is undefined; return NaN to
      // indicate that the solution is invalid
      x=nanf("");
      break;
    }
    else
      x-= f/fp;
    eassert(x==x);
    eassert(x<1e4);
    eassert(x>0);

    if(Nstep>=maxsteps) {
      // no convergence, return NaN.
      x=nanf("");
      break;
    }
  }
  return x;
}



 // index functions for the shared mem arrays
#define CHECK_BANK_CONFLICTS 0
#if CHECK_BANK_CONFLICTS
#define SIGMA(r,c) CUT_BANK_CHECKER((sigmaS), \
				     (BLOCKSIZE * (r) + (c)))
#define INTENSITY(r, c) CUT_BANK_CHECKER((intensityS), \
					 (BLOCKSIZE * (r) + (c)))
#define TEMP(r, c) CUT_BANK_CHECKER((tempS), \
					 (BLOCKSIZE * (r) + (c)))
#else
// using __umul24 here makes it slower
#define SIGMA(r, c) sigmaS[(r)*(BLOCKSIZE+1)+(c)]
#define INTENSITY(r, c) intensityS[(r)*(BLOCKSIZE+1)+(c)]
#define TEMP(r, c) tempS[(r)*(BLOCKSIZE+1)+(c)]
#endif


// the general pointer to the shared memory
extern __shared__ char sharedmem[];

/* Calculating heating is essentially a matrix multiplication of
   sigmadlambda and intensity. The code is largely based on the cuda
   matrixmul example. Arrays are stored C-style (row-major, in (r,c)
   format) and are heating(species,cell), sigma(species, lambda), and
   intensity(cell, lambda). Each thread computes one element
   heating(species, cell). The pitch values should be in index units,
   not memory addresses. If the intensity array is doubles, it is
   converted to float when staging into shared memory. Actually this
   function calculates heating/4pi; we ignore the 4pi because it's
   compensated for by a similar 4pi in the emission calculation, but
   it makes a difference when comparing physical heating values.
   Requires blockDim.x == blockDim.y == BLOCKSIZE and
   (blockDim.y + blockDim.x)*(BLOCKSIZE+1) floats of dynamic shared
   memory. NOTE(review): the ns and nc arguments are not read in the
   body -- correctness outside the true array bounds relies entirely
   on the caller having padded and zeroed the arrays. */
template<typename T_input>
__global__ void
calculate_heating( T_float* heating, size_t heatingP,
		   const T_float* sigmadlambda, size_t sigmaP,
		   const T_input* intensity, size_t intensityP,
		   size_t ns, size_t nc, size_t nl)
{
  const size_t bsl = BLOCKSIZE;

  // Block index (x and y is confusing because they are ordered differently)
  const size_t bc = blockIdx.x;
  const size_t bs = blockIdx.y;
  
  // Thread index
  const size_t tc = threadIdx.x;
  const size_t ts = threadIdx.y;

  // block size. NOTE THAT bsx=bsy=bsl, otherwise the staging will not work!
  const size_t bsc = blockDim.x;
  const size_t bss = blockDim.y;

  // extract the proper shared memory arrays:
  // total smem required bsl*(bsx+bsy+1)*sizeof(T_float) bytes
  // sub-matrix of sigmadlambda, size (bss, bsl+1)
  T_float* sigmaS = (T_float*)sharedmem;
  
  // sub-matrix of intensity, size (bsc, bsl+1) (+1 to avoid bank conflicts)
  T_float* intensityS = sigmaS + bss*(bsl+1);
  
  // Each block processes a sub-matrix of heating starting with cell
  // bc*bsc and species bs*bss, of size bsc*bsr.

  // These are made up of parts of sigma and intensity, which we stage
  // in shared memory in chunks of bsl wavelengths at a time.

  // Index of the (first element of the) first sub-matrix of sigma
  // processed by the block. This is sigma(0, by*bsize). Remember to
  // use the pitch.

  // Step size used to iterate through the sub-matrices of
  // sigma. Because lambda is the lowest stride index, this is just
  // bsl.
  
  // Index of the (first element of the) first sub-matrix of intensity
  // processed by the block. This is intensity(0, bx*bsize). Remember
  // to use the pitch.

  // Step size used to iterate through the sub-matrices of
  // intensity. Because lambda is the lowest stride index of this
  // variable too, it's also just bsl.
  
  // h is used to store the element of heating that is computed by the
  // thread.
  T_float h = 0.0f;

  // Loop over all the lambda blocks. We will process nl values.
  for (size_t lblock=0;
       lblock<nl; 
       lblock+=bsl) {
    
    // Load the blocks from device memory to shared memory; each
    // thread loads one element of each matrix. 

    // To avoid bank conflicts in the computation loop, we need to
    // put lambda in the row-index of the shared memory arrays
    // we get bank conflicts here instead but that's only 1/16th the time.
    
    // load sigmadlambda(lambda, size).
    SIGMA(tc,ts) = T_float(sigmadlambda[ (bs * bss + ts) * sigmaP + 
					 lblock +  tc]);

    // Load (c, lambda) of intensity (note that c is indexed by *ts*
    // here). NOTE(review): the original comment here claimed
    // subsampling is handled with a shift, but no subsampling is
    // performed in this load -- each thread loads exactly one
    // (cell, lambda) element. Confirm against the caller whether
    // subsampled intensity arrays are ever passed in.
    INTENSITY(tc,ts) = T_float(intensity[ (bc * bsc + ts) * intensityP + 
					  (lblock + tc) ]);

    eassert((SIGMA(tc,ts)==SIGMA(tc,ts)));
    eassert((INTENSITY(tc,ts)==INTENSITY(tc,ts)));

    // Synchronize to make sure the matrices are loaded
    __syncthreads();

    // Now calculate the amount of heating resulting from the
    // wavelengths in this block sub-matrix. Because dlambdaS is zero
    // for wavelengths outside the boundary, we don't have to check
    // the length of the lambda array.
    for (size_t l = 0; l < bsl; ++l) {
      // NOTE(review): these (disabled) easserts index SIGMA(ts,l)
      // and INTENSITY(tc,l), which is transposed relative to the
      // SIGMA(l,ts)/INTENSITY(l,tc) accesses in the accumulation
      // below -- verify the intended indices before re-enabling.
      eassert(SIGMA(ts,l)==SIGMA(ts,l));
      eassert(INTENSITY(tc,l)==INTENSITY(tc,l));
      h += SIGMA(l,ts) * INTENSITY(l,tc);
      eassert(h==h);
      eassert(h>=0);
    }

    // Synchronize to make sure that the preceding
    // computation is done before loading two new
    // sub-matrices of A and B in the next iteration
    __syncthreads();
  }
  
  // Write the heating block sub-matrix to device memory; each thread
  // writes one element. Because both intensity and sigma in the
  // padded area are set to zero, we don't need to check that we are
  // within bounds, the heating rate in the padded area will be zero.
  // Threads execute in x-warps (ie cells), so to get coalesced stores
  // we need to have this be cell-minor

  // this number a was divided out from the blackbody function so we
  // divide it out here too to better stay in single-prec range
  const T_float inva = 1.0f/1.19104393407e-16f; 

  heating[ (bs*bss+ts) * heatingP + bsc * bc + tc] = h*inva;
}


/* For the calculation of heating and temperature we don't need sigma
   by itself but rather sigma*dlambda, so this kernel computes that
   quantity (for both sigma and esigma). Each thread handles one
   element of sigma(species, lambda). If the threads access
   consecutive lambda, the loads/stores will be coalesced, so the
   block just needs to be >16 wide in x. */
template <typename T_input>
__global__ void
calculate_sigmadlambda( const T_input* sigma, size_t sigmaP,
			const T_input* lambda, 
			T_float* sigmadlambda, size_t sdlP,
			size_t nl)
{
  // (species, wavelength) element this thread is responsible for
  const size_t species = blockDim.y*blockIdx.y + threadIdx.y;
  const size_t lam = blockDim.x*blockIdx.x + threadIdx.x;

  // Build delta lambda: one-sided differences at the grid edges, a
  // centered difference (spanning two bins) in the interior.
  T_input dl;
  if(lam==0)
    dl = lambda[lam+1] - lambda[lam];
  else if(lam>=nl-1)
    dl = lambda[lam] - lambda[lam-1];
  else
    dl = lambda[lam+1] - lambda[lam-1];

  // the interior difference spans two bins, hence the factor 0.5
  sigmadlambda[species*sdlP+lam] = T_float(sigma[species*sigmaP+lam]*0.5*dl);
  eassert(sigmadlambda[species*sdlP+lam]>0);
}

/* Calculates the equilibrium temperature of the grains by
   Newton-Raphson solving (see NRsolve). Each thread calculates
   T(species, cell). For each iteration we need the lambda and
   sigma*dlambda vectors for the species, so those are staged in
   shared memory; each block handles a single species (blockDim.y
   must be 1). NOTE(review): the staging below needs (2*nl + ns)
   floats of dynamic shared memory (sigmaS + lambdaS + sizeS), not
   the 2*nl mentioned in the comment further down -- confirm the
   launch configuration. */
template <typename T_input>
__global__ void
calculate_temp( T_float accuracy, 
		T_float* temp, size_t tempP,
		const T_float* heating, size_t heatingP,
		const T_float* sigmadlambda, size_t sigmaP,
		const T_input* lambda, const T_input* size,
		size_t ns, size_t nc, size_t nl)
{
  // the element we are processing
  const size_t s=blockDim.y*blockIdx.y+threadIdx.y;
  const size_t c=blockDim.x*blockIdx.x+threadIdx.x;

  // block size. bss should be one, which also means no threads will
  // execute outside of ns, so we don't need to check that.
  const size_t bsc = blockDim.x;
  //eassert(blockDim.y==1);

  // Thread index
  const size_t tc = threadIdx.x;
  eassert(threadIdx.y==0);

  // extract the proper shared memory arrays:
  // total smem required 2*nl
  // sigma for species s, size (1,nl)
  T_float* sigmaS = (T_float*)sharedmem;
  
  // lambda, size (nl). There is no padding issue with lambda, because
  // it's not one of the thread parameters, it's just a loop in each
  // thread.
  T_float* lambdaS = sigmaS + nl;

  // grain sizes, size (ns)
  T_float* sizeS = lambdaS + nl;

  // load the shared memory arrays. each thread loads one element and
  // we loop until we have all elements.
  // NOTE(review): the loop bound is nl, so sizeS is only fully
  // staged if ns <= nl -- confirm that this always holds for the
  // grain models used.
  for (size_t lblock=0; lblock<nl; lblock+=bsc) {
    if(lblock+tc<nl) {
      sigmaS[lblock+tc] = T_float(sigmadlambda [s*sigmaP + lblock+ tc]);
      lambdaS[lblock+tc] = T_float(lambda [lblock + tc]);
    }
    if(lblock+tc<ns) {
      sizeS[lblock+tc] = T_float(size[lblock+tc]);
    }
  }

  __syncthreads();

  // We must check that we are within bounds because the block size is
  // 16 but the grid size is 256, so the assignment will overrun the
  // temp array. The nc passed is actually ncPadded, so we *will* zero
  // out the padded area.

  if(c<nc) {
    // Get heating
    const T_float h = heating[s*heatingP + c]; 

    // because the solver uses 1/T, it will go crazy for zero temp. We to check
    // explicitly for zero heating, which can happen if no rays hit
    // the cell. (Or in the padded area.)
    T_float T;

    if(h>0.0f) {

      // We get the initial guess from our fit (an empirical fit of T0
      // as a function of heating and grain size). Remember that the
      // heating was rescaled by 1/1.19e-16, so multiply it back
      // before taking the log.
      T_float xx = log10(h*1.19104393407e-16f) -
	1.39f*log10(sizeS[s]) + 
	0.126f*pow(log10(sizeS[s]),2);
      T_float T0 =
	pow(10.0f,1.86f+0.189f*xx+3.41e-3f*xx*xx);
      // clamp the initial guess to at least 5 K
      T0 = (T0<5.0f)?5.0f:T0;

      eassert(T0==T0);
      eassert(T0>0);
      // solve to a tolerance relative to the heating rate
      T = NRsolve( T0, accuracy*h, h, sigmaS, lambdaS, nl);
    }
    else
      T=0.0f;

    eassert(T==T);
    eassert(T>=0);

    temp[s*tempP+c] = T;
  }
};



/* Calculating the SED is also essentially a matrix multiplication of
   sigma and B_lambda(temp), with the added lambda vector as an
   argument to B_lambda. Arrays are sed(cell, lambda), sigma(species,
   lambda), and temp(species, cell). Each thread calculates
   sed(cell, lambda). Like calculate_heating, each block stages blocks
   in shared mem. The nltrue parameter (true, unpadded number of
   wavelengths) only exists in device-emulation builds, where it is
   used to silence assertions in the padded area; in normal builds the
   easserts compile to nothing. */
template <typename T_input>
__global__ void
calculate_sed ( T_input* sed, size_t sedP,
		const T_float* temp, size_t tempP,
		const T_input* sigma, size_t sigmaP,
		const T_input* lambda, const T_input* mdust, const T_input* dn,
		size_t ns, size_t nc, size_t nl,
		bool add_to_sed
#ifdef __DEVICE_EMULATION__
		,size_t nltrue
#endif
		)
{
  const size_t bss = BLOCKSIZE;

  // Block index (x and y is confusing because they are ordered differently)
  const size_t bl = blockIdx.x;
  const size_t bc = blockIdx.y; // can't be short bc too many cells
  
  // Thread index
  const size_t tl = threadIdx.x;
  const size_t tc = threadIdx.y;

  // block size. NOTE THAT bsx=bsy=bsl, otherwise the staging will not work!
  const size_t bsl = blockDim.x;
  const size_t bsc = blockDim.y;

  // extract the proper shared memory arrays:
  // total smem required (bsl+bsc)*(bss+1)+bsl+bsc+bss
  // = 2*bs*(bs+1)+3*bs = bs*(3+2*(bs+1))

  // sub-matrix of sigma, size (bsl, bss+1)
  T_float* sigmaS = (T_float*)sharedmem;
  
  // sub-matrix of (1/temp), size (bsc, bss+1) (+1 to avoid bank conflicts)
  T_float* tempS = sigmaS + bsl*(bss+1);
  
  // sub-vector of lambda, size (bsl)
  T_float* lambdaS = tempS + bsc*(bss+1);

  // sub-vector of mdust, size (bsc)
  T_float* mdustS = lambdaS + bsl;

  // sub-vector of dn, size (bss)
  T_float* dnS = mdustS + bsc;

  // Each block processes a sub-matrix of sed starting with cell
  // bc*bsc and lambda bl*bsl, of size bsc*bsl.

  // These are made up of parts of sigma and temp, which we stage
  // in shared memory in chunks of bss species at a time.

  // Index of the (first element of the) first sub-matrix of sigma
  // processed by the block. This is sigma(0, bl*bsl). Remember to
  // use the pitch.

  // Step size used to iterate through the sub-matrices of
  // sigma is sigmaP, because species is highest stride.

  // Index of the (first element of the) first sub-matrix of temp
  // processed by the block. This is temp(0, bc*bsc). Remember
  // to use the pitch.

  // Step size used to iterate through the sub-matrices of
  // temp is tempP, because species is the highest stride.
  
  // sedEl is used to store the element of sed that is computed by the
  // thread.
  T_float sedEl = 0;

  // l can be >=nl in which case we will write crap but that's in the
  // padding anyway
  const size_t l=bl*bsl+tl; // l (or tl) is unit thread num

  // We can stage mdust and lambda immediately, because they don't
  // depend on the blocking in the s-dir. Indexed by tl for coalescence.
  // (threads with same tl duplicate load but that is of no consequence)
  mdustS[tl] = T_float(mdust[bc*bsc+tl]);

  lambdaS[tl] = T_float(lambda[ l ]);

#ifdef __DEVICE_EMULATION__
   // if we are emulating, this must be done or the NaN assertions
   // fail in B_lambda for the outside-boundary elements.
   if (l>=nltrue)
     lambdaS[tl]=1.0f;
#endif
  // Loop over all the species blocks. We will process ns values.
  for (size_t sblock=0;
       sblock<ns; 
       sblock+=bss) {
    
    // Load the blocks from device memory to shared memory; each
    // thread loads one element of each matrix. c-index is s during
    // this load.
    
    // load sigma, indexed by (c-index, l-index)
    SIGMA(tc,tl) = T_float(sigma[ sblock*sigmaP + tc*sigmaP  + l ]);

    // load temp, indexed by (c-index, l-index). Note c is indexed
    // by *tl*. NOTE that we stage 1/temp because that's what's used.
    TEMP(tc,tl) = 1.0f/temp[ sblock* tempP + tc*tempP + bc * bsc + tl ];

    // load dn, indexed by tl (for coalesced load). threads with same
    // tl will duplicate load but that is of no consequence.
    dnS[tl] = T_float(dn[sblock+tl]);

    eassert(lambdaS[tl]==lambdaS[tl]);
    eassert((SIGMA(tc,tl)==SIGMA(tc,tl)));
    eassert((TEMP(tc,tl)==TEMP(tc,tl)));

    // Synchronize to make sure the matrices are loaded
    __syncthreads();

    // Now calculate the sed contribution from the species in this
    // block sub-matrix. For species outside the s-boundary, sigmaS is
    // zero, and temp is zero too from the earlier calculations, so we
    // don't have to check the length of the array. If we are outside
    // lambda array, we get zero lambdas which give Nan from B_lambda,
    // so sedEl will be NaN. Thats's ok, except that it screws with
    // the assertions in the B_lambda function.
    for (size_t s = 0; s < bss; ++s) {
      eassert(SIGMA(s,tl)==SIGMA(s,tl));
      eassert(TEMP(s,tc)==TEMP(s,tc));
      // if temp is zero we can't call the function because we get
      // divide by zero. in that case we just continue.
      // XXX we have 1/temp in that variable, so 1/T > 1e10 means
      // T was (essentially) zero and 1/T is huge/inf.
      if(TEMP(s,tc)>1e10)
	continue;
      sedEl += 
	SIGMA(s,tl) * 
	B_lambda(1.0f/lambdaS[tl], TEMP(s,tc)) *
	dnS[s];
      eassert((l>=nltrue) || (sedEl==sedEl));
    }

    // Synchronize to make sure that the preceding
    // computation is done before loading two new
    // sub-matrices of A and B in the next iteration
    __syncthreads();
  }
  
  // Write the sed block sub-matrix to device memory; each thread
  // writes one element. If we are asked to add to the SED, then we do
  // so (which will be slightly slower because we will also load).
  // Also put the scale factor back from the B_lambda rescaling:
  // 4*pi (piD) times the 2hc^2 prefactor times the SED normalization
  // (sed_normD), all read from __constant__ memory.
  const T_input a = 4.0*piD*1.19104393407e-16*sed_normD; 
  if(add_to_sed)
    sed[ (bc*bsc+tc) * sedP + l ] += T_input(sedEl)*a *mdustS[tc];
  else
    sed[ (bc*bsc+tc) * sedP + l ] = T_input(sedEl)*a *mdustS[tc];

  eassert((l>=nltrue) || (sed[ (bc*bsc+tc) * sedP + l ]==sed[ (bc*bsc+tc) * sedP + l ]));
}


/** Host-side driver for the equilibrium dust SED calculation. Copies
    the static (cell-independent) inputs to the device, launches the
    sigma*dlambda kernels, and then processes the cells in blocks
    sized to fit in device memory (see
    process_block_thermal_equilibrium). All host array pointers are
    unpadded; padding to multiples of BLOCKSIZE is handled here. If
    the optional heating/temp host pointers are non-null, those
    intermediate arrays are copied back for testing. Returns the cutil
    timer handle for the total GPU time. NOTE(review): most CUDA API
    return codes here are not checked individually; errors surface at
    the next checkerr() call. */
template<typename T_input>
unsigned int mcrxcuda::calculate_equilibrium_SED(const T_input* sigma, 
						 const T_input* esigma, 
						 const T_input* intensity, 
						 const T_input* lambda, 
						 const T_input* elambda, 
						 const T_input* size,
						 const T_input* m_dust, 
						 const T_input* dn,
						 T_input* sed,
						 size_t ns, size_t nc, 
						 size_t nl, size_t nel, 
						 bool add_to_sed, 
						 T_float accuracy,
						 T_float sed_norm,
						 T_float* heating, T_float* temp)
{
  // get device memory
  int cuda_dev;
  cudaGetDevice(&cuda_dev);
  cudaDeviceProp dp;
  cudaGetDeviceProperties(&dp, cuda_dev);
  const size_t devmem = dp.totalGlobalMem;

  // padded sizes to multiple of BLOCKSIZE
  const size_t nsPadded = (ns%BLOCKSIZE) ? (ns/BLOCKSIZE+1)*BLOCKSIZE : ns;
  const size_t ncPadded = (nc%BLOCKSIZE) ? (nc/BLOCKSIZE+1)*BLOCKSIZE : nc;
  const size_t nlPadded = (nl%BLOCKSIZE) ? (nl/BLOCKSIZE+1)*BLOCKSIZE : nl;
  const size_t nelPadded = (nel%BLOCKSIZE) ? (nel/BLOCKSIZE+1)*BLOCKSIZE : nel;

  // estimate the device memory needed by each kernel so the cell
  // blocking below can be chosen to fit
  //XXX FIXME for esigma
  const size_t estmem_sigma = sizeof(T_input)*nlPadded*nsPadded;
  const size_t estmem_esigma = sizeof(T_input)*nelPadded*nsPadded;
  const size_t estmem_sed = sizeof(T_input)*nelPadded*ncPadded;
  const size_t estmem_int = 
    (sizeof(T_input)*nlPadded*ncPadded);
  // heating, temp
  const size_t estmem_a2 = sizeof(float)*ncPadded*nsPadded; 
  // kernel 2 uses sigma, int., heating
  const size_t estmem2 = estmem_sigma + estmem_int + estmem_a2; 

  // kernel 3 uses sigma, heating, temp
  const size_t estmem3 = estmem_esigma + 2*estmem_a2; 
  // kernel 4 uses (sigma, temp, sed)
  const size_t estmem4 = estmem_esigma + estmem_a2 + estmem_sed; 
  const size_t estmem = std::max(std::max(estmem2,estmem3),estmem4);

  DEBUG(1,cout << "Estimated % global memory needed: " << 100.0*estmem/devmem << endl;);

  unsigned int totalTimer=0;
  cutCreateTimer(&totalTimer);
  cutStartTimer(totalTimer);

  // copy constants (4*atan(1)=pi and the SED normalization) to
  // constant memory. T_float and T_float32 are both float, so the
  // sizeof(T_float32) here matches the __constant__ declarations.
  cudaMemcpyToSymbol(piD,&pif,sizeof(T_float32),0,cudaMemcpyHostToDevice);
  cudaMemcpyToSymbol(sed_normD,&sed_norm,sizeof(T_float32),0,
		     cudaMemcpyHostToDevice);

  // all arrays are padded to even multiples of the block size so we
  // don't have to worry about threads that are out of bounds

  // *** allocate and copy data not dependent on cells to device
  // *** memory

  T_input *sigmaD=0, *esigmaD=0, *lambdaD=0, *elambdaD=0, *sizeD=0, *dnD=0;
  size_t sigmaP, esigmaP;

  // For the lambda values, we don't blank out with a value of 0,
  // because that leads to a division by zero in the sed kernel which
  // uses 1/lambda. Any value will do, as long as it's not 0 or NaN.
  // (cudaMemset fills bytes, so the padding is filled with 0x01
  // bytes, a tiny but nonzero float/double.)
  cudaMalloc((void**)&lambdaD, sizeof(T_input)*nlPadded);
  cudaMemset(lambdaD, 1, sizeof(T_input)*nlPadded);
  cudaMemcpy(lambdaD, lambda, sizeof(T_input)*nl, cudaMemcpyHostToDevice);
  cudaMalloc((void**)&elambdaD, sizeof(T_input)*nelPadded);
  cudaMemset(elambdaD, 1, sizeof(T_input)*nelPadded); 
  cudaMemcpy(elambdaD, elambda, sizeof(T_input)*nel, cudaMemcpyHostToDevice);

  cudaMalloc((void**)&sizeD, sizeof(T_input)*nsPadded);
  // to blank out padding (if it's 0 we get nan in the padded area)
  cudaMemset(sizeD, 0, sizeof(T_input)*nsPadded);
  cudaMemcpy(sizeD, size, sizeof(T_input)*ns, cudaMemcpyHostToDevice);

  // because max pitch in cudamallocpitch is too small for blocks that
  // have number of cells in the lowest stride direction, we don't use
  // it for the 2d arrays unless we have to copy the data back to the
  // host. The block size should already be padded such that the
  // blocks are aligned. This affects heating and temp arrays.

  cudaMallocPitch((void**)&sigmaD, &sigmaP, sizeof(T_input)*nlPadded, nsPadded);
  // Blanks out the array so the padded area contains zero. That way
  // we can ignore bounds when calculating without risk getting NaN.
  // Do NOT delete this.
  cudaMemset2D(sigmaD, sigmaP, 0, sizeof(T_input)*nlPadded, nsPadded);
  assert(sigmaP%sizeof(T_input)==0);
  cudaMemcpy2D(sigmaD, sigmaP, sigma, sizeof(T_input)*nl, 
	       sizeof(T_input)*nl, ns, cudaMemcpyHostToDevice);

  cudaMallocPitch((void**)&esigmaD, &esigmaP, sizeof(T_input)*nelPadded, nsPadded);
  cudaMemset2D(esigmaD, esigmaP, 0, sizeof(T_input)*nelPadded, nsPadded);
  cudaMemcpy2D(esigmaD, esigmaP, esigma, sizeof(T_input)*nel, 
	       sizeof(T_input)*nel, ns, cudaMemcpyHostToDevice);

  // NOTE(review): unlike sizeD etc. above, the padded tail of dnD
  // (elements ns..nsPadded-1) is never memset, so the sed kernel
  // stages uninitialized values into dnS for the padded species.
  // They are multiplied by sigma values that are zero in the padding,
  // but if the garbage happens to be NaN/inf the product is NaN --
  // confirm whether this padding should be zeroed like the others.
  cudaMalloc((void**)&dnD, sizeof(T_input)*nsPadded);
  cudaMemcpy(dnD, dn, sizeof(T_input)*ns, cudaMemcpyHostToDevice);

  T_float32 *sigmadlambdaD=0, *esigmadlambdaD=0;
  size_t sdlP, esdlP;
  // as of cuda v3.2, we apparently get different pitches for the 4-
  // and 8-byte arrays, so we must supply all pitches.
  cudaMallocPitch((void**)&sigmadlambdaD, &sdlP, 
		  sizeof(T_float32)*nlPadded, nsPadded);
  cudaMallocPitch((void**)&esigmadlambdaD, &esdlP,
		  sizeof(T_float32)*nelPadded, nsPadded);
  checkerr();

  // first call kernel 1 to calculate sigma*dlambda and esigma*delambda

  {    
    dim3 threads(16, 8);
    // thread x is lambda, y is species
    dim3 grid(nlPadded/threads.x, nsPadded/threads.y);
    DEBUG(1,cout << "Calling sigmadlambda kernel (" \
	  << threads.x << ',' << threads.y \
	  << ") (" << grid.x << ',' << grid.y  << ") " << endl;);
    unsigned int timer=0;
    cutCreateTimer(&timer);
    cutStartTimer(timer);
    calculate_sigmadlambda<<< grid, threads >>> 
      ( sigmaD, sigmaP/sizeof(T_input), 
	lambdaD, sigmadlambdaD, sdlP/sizeof(T_float32),
	nl);

    cudaThreadSynchronize();
    cutStopTimer(timer);
    DEBUG(1,printf("Processing time: %f (ms) \n\n", cutGetTimerValue(timer)););
    cutDeleteTimer(timer);
    checkerr();
  }

  // same kernel again for the emission-wavelength grid
  {    
    dim3 threads(16, 8);
    // thread x is lambda, y is species
    dim3 grid(nelPadded/threads.x, nsPadded/threads.y);
    DEBUG(1,cout << "Calling sigmadlambda kernel (" \
	  << threads.x << ',' << threads.y \
	  << ") (" << grid.x << ',' << grid.y  << ") " << endl;);
    unsigned int timer=0;
    cutCreateTimer(&timer);
    cutStartTimer(timer);
    calculate_sigmadlambda<<< grid, threads >>> 
      ( esigmaD, esigmaP/sizeof(T_input), 
	elambdaD, esigmadlambdaD, esdlP/sizeof(T_float32),
	nel);

    cudaThreadSynchronize();
    cutStopTimer(timer);
    DEBUG(1,printf("Processing time: %f (ms) \n\n", cutGetTimerValue(timer)););
    cutDeleteTimer(timer);
    checkerr();
  }

  // We are now set up to do the calculation. If there does not seem
  // to be enough device memory to fit all the arrays on the device,
  // the calculation is done in blocks.
  const size_t block_size = std::min(int(ceil(1.0*nc/ceil(1.0*estmem/devmem))), 
				     65535*BLOCKSIZE);
  const int n_block = int(ceil(1.0*nc/block_size));

  // each *block* needs to be padded to BLOCKSIZE.
  const size_t bsPadded = (block_size%BLOCKSIZE) ? 
    (block_size/BLOCKSIZE+1)*BLOCKSIZE : block_size;

  cout << "Performing the CUDA calculation in " << n_block << " blocks of " 
       << block_size << " cells.\n" << endl;

  // sanity check the incoming SED (finite, non-NaN) before we start
  for(size_t i=0; i<nc*nel; ++i) {
    assert(sed[i]==sed[i]);    
    assert(sed[i]<HUGE_VAL);
  }

  for (size_t c=0; c<nc;) {
    size_t cur_block = (nc-c < bsPadded) ? nc-c : bsPadded;
    size_t cur_padded = (cur_block%BLOCKSIZE) ? 
      (cur_block/BLOCKSIZE+1)*BLOCKSIZE : cur_block;

    cout << "Running block of cells " << c << " - " 
	 << c+cur_block-1 << " padded to " 
	 << c+cur_padded-1 << endl;

    // pick out the correct starting points to the arrays for this block
    const T_input* cur_intensity = intensity + nl*c;
    const T_input* cur_mdust = m_dust + c;
    T_input* cur_sed = sed + nel*c;

    // This function calls the computation kernels and copies the
    // intensity and SED arrays back and forth.
    process_block_thermal_equilibrium(sigmaD, sigmaP, 
				      esigmaD, esigmaP, 
				      lambdaD, elambdaD, 
				      sizeD, 
				      sigmadlambdaD, sdlP, 
				      esigmadlambdaD, esdlP,
				      dnD, cur_intensity, cur_mdust, cur_sed,
				      ns, cur_block, nl, nel, nsPadded, cur_padded, 
				      nlPadded, nelPadded,
				      accuracy, 
				      add_to_sed,
				      heating ? heating+c : 0, 
				      temp ? temp+c : 0,
				      nc);
    checkerr();

    c+= cur_block;
  }

  cudaFree(sigmaD);
  cudaFree(esigmaD);
  cudaFree(lambdaD);
  cudaFree(elambdaD);
  cudaFree(dnD);
  cudaFree(sigmadlambdaD);
  cudaFree(esigmadlambdaD);

  cudaThreadSynchronize();
  checkerr();
  cutStopTimer(totalTimer);

  printf("Total GPU processing time: %f (ms) \n\n", cutGetTimerValue(totalTimer));

  if(heating) {
    // apply the missing factor from B_lambda
    for(size_t i=0;i<nc*ns;++i)
      heating[i]*= 1.19104393407e-16f; 
  }

  // sanity check the final SED (finite, non-NaN)
  for(size_t i=0; i<nc*nel; ++i) {
    assert(sed[i]==sed[i]);
    assert(sed[i]<HUGE_VAL);
  }

  return totalTimer;
}




/** This function calls the kernels to do the calculation for one
    block of cells. Some of the input data (those that don't depend on
    the cells) should already have been copied to the device, but the
    intensity and sed arrays are *host* pointers. They are transferred
    by this function.  The sizes and arrays should be padded. If the
    heating and temp (host) pointers are specified, those temporary
    arrays are copied back to host memory for testing purposes.

    Pipeline: upload intensity -> "kernel 2" calculate_heating ->
    "kernel 3" calculate_temp -> "kernel 4" calculate_sed -> download
    sed. Device arrays are freed as soon as they are no longer needed
    to keep the peak memory footprint down.

    All pitches are passed to the kernels in *element index* units
    (bytes/sizeof(element)), not bytes; the asserts on pitch
    divisibility guarantee this conversion is exact.

    Returns the accumulated boolean result of the checkerr() calls
    (nonzero iff no CUDA error was observed), widened to unsigned int.

    \param sigmaD,sigmaP device sigma array and its byte pitch
    \param esigmaD,esigmaP device sigma array on the emission
           wavelength grid and its byte pitch
    \param lambdaD,elambdaD device wavelength grids (absorption/emission)
    \param sizeD device grain size array
    \param sigmadlambdaD,sdlP device sigma*dlambda array and byte pitch
    \param esigmadlambdaD,esdlP same on the emission grid
    \param dnD device grain number distribution
    \param intensity host intensity array, nl x nc (unpadded rows)
    \param mdust host dust mass array, nc entries
    \param sed host SED array, nel x nc; read if add_to_sed, always written
    \param ns,nc,nl,nel logical sizes (species, cells, wavelengths,
           emission wavelengths)
    \param nsPadded,ncPadded,nlPadded,nelPadded padded sizes used for
           the device arrays and kernel grids
    \param accuracy temperature-solver convergence tolerance
    \param add_to_sed if true, the existing host sed is uploaded and
           accumulated into rather than overwritten
    \param heating,temp optional host arrays (ncfull x ns layout) that
           receive the intermediate heating/temperature for testing;
           pass 0 to skip the copy-back
    \param ncfull row stride of the full host heating/temp arrays */
template<typename T_input>
unsigned int 
mcrxcuda::
process_block_thermal_equilibrium(const T_input* sigmaD, size_t sigmaP,
				  const T_input* esigmaD, size_t esigmaP,
				  const T_input* lambdaD, 
				  const T_input* elambdaD, 
				  const T_input* sizeD, 
				  const T_float* sigmadlambdaD, size_t sdlP,
				  const T_float* esigmadlambdaD, size_t esdlP,
				  const T_input* dnD,
				  const T_input* intensity, 
				  const T_input* mdust, 
				  T_input* sed,
				  size_t ns, size_t nc, 
				  size_t nl, size_t nel,
				  size_t nsPadded, size_t ncPadded, 
				  size_t nlPadded, size_t nelPadded,
				  T_float accuracy, 
				  bool add_to_sed,
				  T_float* heating, T_float* temp, 
				  size_t ncfull)
{

  // accumulates the success/failure of every checkerr() call below
  bool good=true;

  // Allocate intensity array on device.
  size_t intensityP;
  T_input* intensityD=0;
  cudaMallocPitch((void**)&intensityD, &intensityP, 
		  sizeof(T_input)*nlPadded, ncPadded);
  // the pitch must be a whole number of elements so it can be passed
  // to the kernel in index units
  assert(intensityP%sizeof(T_input)==0);

  // Blank out padding so we can ignore bounds. (Because the loop goes
  // over wavelengths in the padded area, we *need* those to be set to
  // zero to assure that the values in the padding gets set to 0. Do
  // not delete this.)
  if(nlPadded>nl) {
    // this is a little more complicated than it could be because we
    // must align the pointer to the subarray to be blanked by 256 bytes
    // (subwidth is rounded up to the next multiple of 256 and the
    // blanked region is anchored at the *end* of each row, so it
    // covers at least the padding columns; it may also re-clear the
    // tail of the valid columns, which is harmless because the
    // cudaMemcpy2D below overwrites the valid part afterwards)
    size_t subwidth = intensityP-sizeof(T_input)*nl;
    subwidth = ((subwidth>>8)+1)<<8;
    subwidth = min(subwidth, intensityP);
    cudaMemset2D(reinterpret_cast<char*>(intensityD)+intensityP-subwidth,
		 intensityP, 0, 
    		 subwidth, ncPadded); 
  } 
  // zero the padding rows (cells beyond nc)
  if(ncPadded>nc) 
    cudaMemset2D(reinterpret_cast<char*>(intensityD)+intensityP*nc, 
		 intensityP, 0, 
		 sizeof(T_input)*nlPadded, ncPadded-nc);

  // upload the (unpadded) host intensity rows into the pitched device array
  cudaMemcpy2D(intensityD, intensityP, intensity, 
	       sizeof(T_input)*nl, 
	       sizeof(T_input)*nl, nc, 
	       cudaMemcpyHostToDevice);

  // Allocate heating array on device. We manually assign pitch
  // beacuse the pitch is too large for MallocPitch.
  // (heatingP equals the row width exactly, i.e. rows are contiguous)
  const size_t heatingP=sizeof(T_float)*ncPadded;
  T_float* heatingD;
  cudaMalloc((void**)&heatingD, sizeof(T_float)*ncPadded*nsPadded);
  assert(heatingP%sizeof(T_float)==0);

  good &= checkerr();

  // Call kernel 2 to calculate heating. this is all done with lambda
  // as wavelength
  {
    dim3 threads(BLOCKSIZE, BLOCKSIZE);
    // thread x is cell, y is species
    // the padded sizes must tile exactly into BLOCKSIZE blocks
    assert(ncPadded%BLOCKSIZE==0);
    assert(nsPadded%BLOCKSIZE==0);
    dim3 grid(ncPadded/BLOCKSIZE, nsPadded/BLOCKSIZE);
    const size_t smem=(BLOCKSIZE+1)*(threads.x+threads.y+1)*sizeof(T_float);
    DEBUG(1,cout << "Calling heating kernel (" \
	  << threads.x << ',' << threads.y \
	 << ") (" << grid.x << ',' << grid.y  << ") " \
	  << smem << endl;);

    unsigned int timer=0;
    cutCreateTimer(&timer);
    cutStartTimer(timer);

    // we call the kernel with pitch in float array index units,
    // ie /sizeof(T_float).
    calculate_heating<<< grid, threads, smem >>>
      (heatingD, heatingP/sizeof(T_float),
       // note: sigmaP is the pitch for the sigma array, ie the one with doubles
       sigmadlambdaD, sdlP/sizeof(T_float), 
       intensityD, intensityP/sizeof(T_input),
       nsPadded, ncPadded, nlPadded);

    // synchronize so the timer measures the actual kernel execution
    cudaThreadSynchronize();
    cutStopTimer(timer);
    DEBUG(1,printf("Processing time: %f (ms) \n\n", cutGetTimerValue(timer)););
    cutDeleteTimer(timer);
    good &= checkerr();
  }

  // we are done with intensity now -- free the array
  cudaFree(intensityD);

  if(heating) {
    // copy heating back, if applicable.
    // The device layout is padded (ncPadded per row), the host layout
    // is ncfull per row, so the copy has to go species row by species row.
    T_float* tempheating = new T_float[ncPadded*nsPadded];
    cudaMemcpy(tempheating, heatingD, sizeof(T_float)*ncPadded*nsPadded,
	       cudaMemcpyDeviceToHost);
    for(size_t s=0; s<ns; ++s) {
      memcpy(reinterpret_cast<void*>(heating+ncfull*s), 
	     reinterpret_cast<void*>(tempheating+ncPadded*s), 
	     sizeof(T_float)*nc);
      // sanity check: physical heating rates must be strictly positive
      for(size_t c=0; c<nc; ++c)
	assert(heating[ncfull*s+c]>0);
    }
    delete[] tempheating;
    good &= checkerr();
  }

  // allocate temp array, manually assigning pitch
  const size_t tempP=sizeof(T_float)*ncPadded;
  T_float* tempD=0;
  cudaMalloc((void**)&tempD, sizeof(T_float)*ncPadded*nsPadded);
  assert(tempP%sizeof(T_float)==0);

  good &= checkerr();

  // now call kernel 3 to calculate temp. this is done with elambda
  {    
    // since block size is 1 species, no need to use padding there BUT
    // we want the temperature to be set to zero in the padded area,
    // so we still call it out to nsPadded. (In that area it will be
    // very fast anyway because all heatings are zero.)  The padding
    // is only to 16 in cell which is way smaller than our block size,
    // so we still need the check in cell dimension, i.e.  no need to
    // pad there either.
    dim3 threads(256, 1);
    // thread x is cell, y is species
    dim3 grid(int(ceil(1.0*nc/threads.x)), nsPadded);
    // we still need to check out of bounds wrt nc here because our
    // block size is so big. And no need to pad in nl because that's
    // the computational loop, not a thread index.
    const int smem=2*nel*sizeof(T_float)+ns*sizeof(T_float);
    DEBUG(1,cout << "Calling temp calc kernel (" \
	  << threads.x << ',' << threads.y \
	 << ") (" << grid.x << ',' << grid.y  << ") " \
	  << smem << endl;);
    unsigned int timer=0;
    cutCreateTimer(&timer);
    cutStartTimer(timer);
    calculate_temp<<< grid, threads, smem >>> ( accuracy, 
						tempD, tempP/sizeof(T_float),
						heatingD, heatingP/sizeof(T_float),
						esigmadlambdaD, esdlP/sizeof(T_float),
						elambdaD, sizeD,
						ns, ncPadded, nel);
    cudaThreadSynchronize();
    cutStopTimer(timer);
    DEBUG(1,printf("Processing time: %f (ms) \n\n", cutGetTimerValue(timer)););
    cutDeleteTimer(timer);
  }

  good &= checkerr();

  // we are done with heating now.
  cudaFree(heatingD);

  if(temp) {
    // copy temp array back, if applicable. The pitch can be too large
    // so we have to do it with a temporary array.
    T_float* temptemp = new T_float[ncPadded*nsPadded];
    cudaMemcpy(temptemp, tempD, sizeof(T_float)*ncPadded*nsPadded, 
	       cudaMemcpyDeviceToHost);
    for(size_t s=0; s<ns; ++s) {
      memcpy(reinterpret_cast<void*>(temp+ncfull*s), 
	     reinterpret_cast<void*>(temptemp+ncPadded*s), 
	     sizeof(T_float)*nc);
    }
    delete[] temptemp;
    // sanity checks: temperatures must be finite (x==x catches NaN),
    // non-negative, and below 1e4 K (assumed physical upper bound for
    // grain temperatures here)
    for(size_t i=0; i<nc*ns; ++i) {
      assert(temp[i]==temp[i]);
      assert(temp[i]>=0);
      assert(temp[i]<1e4);
    }

    good &= checkerr();
  }


  // allocate SED and Mdust device arrays for sed calculation. 
  T_input* mdustD=0;
  T_input*  sedD=0;
  size_t sedP;

  cudaMalloc((void**)&mdustD, sizeof(T_input)*ncPadded);
  cudaMemcpy(mdustD, mdust, sizeof(T_input)*nc, cudaMemcpyHostToDevice);

  cudaMallocPitch((void**)&sedD, &sedP, sizeof(T_input)*nelPadded, ncPadded);
  // if we are adding to the SED, we need to upload the array, too
  if(add_to_sed) {
    // validate the incoming host SED before uploading: finite,
    // non-negative, not infinite
    for(size_t i=0; i<nc*nel; ++i) {
      assert(sed[i]==sed[i]);
      assert(sed[i]>=0);
      assert(sed[i]<HUGE_VAL);
    }

    cudaMemcpy2D(sedD, sedP, sed, sizeof(T_input)*nel, 
		 sizeof(T_input)*nel, nc, cudaMemcpyHostToDevice);
  }

  good &= checkerr();
  assert(sedP%sizeof(T_input)==0);

  // Now call kernel 4 to calculate SED. This is done with elambda as
  // the wavelength.

  {    
    dim3 threads(BLOCKSIZE, BLOCKSIZE);
    // thread x is lambda, y is cell
    dim3 grid(nelPadded/threads.x, ncPadded/threads.y);
    const int smem=BLOCKSIZE*(3+2*(BLOCKSIZE+1))*sizeof(T_float);
    DEBUG(1,cout << "Calling sed kernel (" << threads.x << ',' << threads.y \
	 << ") (" << grid.x << ',' << grid.y  << ") " \
	  << smem << endl;);

    unsigned int timer=0;
    cutCreateTimer(&timer);
    cutStartTimer(timer);
    // we call the kernel with pitch in *index* units, ie /sizeof(data).
    calculate_sed<<< grid, threads, smem >>>(sedD, sedP/sizeof(T_input),
					     tempD, tempP/sizeof(T_float),
					     esigmaD, esigmaP/sizeof(T_input),
					     elambdaD, mdustD, dnD,
					     nsPadded, ncPadded, nelPadded,
					     add_to_sed
#ifdef __DEVICE_EMULATION__
					     , nel
#endif
					     );
    cudaThreadSynchronize();
    cutStopTimer(timer);
    DEBUG(1,printf("Processing time: %f (ms) \n\n", cutGetTimerValue(timer)););
    cutDeleteTimer(timer);
  }
  good &= checkerr();

  // copy SED back and free
  cudaMemcpy2D(sed, sizeof(T_input)*nel, sedD, sedP,
	       sizeof(T_input)*nel, nc, cudaMemcpyDeviceToHost);
  // validate the result: finite, non-negative, not infinite
  for(size_t i=0; i<nc*nel; ++i) {
    assert(sed[i]==sed[i]);
    assert(sed[i]>=0);
    assert(sed[i]<HUGE_VAL);
  }

  cudaFree(mdustD);
  cudaFree(tempD);
  cudaFree(sedD);

  good &= checkerr();
  return good;
}


// Explicit instantiations of the calculation routine for float and
// double alternatives. (These are the only T_input types the rest of
// the code links against; note that accuracy/sed_norm/heating/temp
// stay float in both cases.)
template
unsigned int 
mcrxcuda::calculate_equilibrium_SED(const float* sigma, const float* esigma, 
				    const float* intensity, 
				    const float* lambda, const float* elambda, 
				    const float* size,
				    const float* mdust, const float* dn,
				    float* sed,
				    size_t ns, size_t nc, size_t nl, size_t nel, 
				    bool, 
				    float accuracy, float sed_norm,
				    float* heating, float* temp);

template
unsigned int 
mcrxcuda::calculate_equilibrium_SED(const double* sigma, const double* esigma, 
				    const double* intensity, 
				    const double* lambda, const double* elambda, 
				    const double* size,
				    const double* mdust, const double* dn,
				    double* sed,
				    size_t ns, size_t nc, size_t nl, size_t nel, 
				    bool, 
				    float accuracy, float sed_norm,
				    float* heating, float* temp);


/** Selects the CUDA device to use and prints basic information about
    it. If the requested device index is out of range (negative or
    >= the device count), device 0 is used instead.
    \param cuda_dev index of the requested CUDA device.
    \return true if a CUDA device was found and selected without
    error, false if no devices exist. */
bool mcrxcuda::cuda_init(int cuda_dev)
{
  // select device
  int devCnt;
  cudaGetDeviceCount(&devCnt);
  if(devCnt==0) {
    cerr << "No CUDA devices found" << endl;
    return false;
  }
  else
    cout << devCnt << " CUDA devices detected" << endl;

  // fall back to device 0 on an out-of-range request (also guards
  // against negative indices, which the original check missed)
  if(cuda_dev<0 || cuda_dev>=devCnt) {
    cerr << "CUDA device " << cuda_dev << " does not exist. Using device 0" << endl;
    cuda_dev=0;
  }

  // Bug fix: query the properties of the device we are actually going
  // to use, not unconditionally device 0 -- otherwise the reported
  // name and memory size are wrong whenever cuda_dev != 0.
  cudaDeviceProp dp;
  cudaGetDeviceProperties(&dp, cuda_dev);
  const size_t devmem = dp.totalGlobalMem;
  cout << "CUDA executing on " << dp.name << " with memory " << devmem/1024/1024 << "MB"  << endl;
  cudaSetDevice(cuda_dev);
  // report whether the property query / device selection succeeded
  // instead of unconditionally returning true
  return checkerr();
}
